import os
import pandas as pd
import ott
import warnings
import sys
warnings.filterwarnings("ignore")
sys.path.append('/ADBench')
from data_generator import DataGenerator
from myutils import Utils
import re
import gama
from gama import GamaClassifier

from sklearn.decomposition import FastICA
from pyod.models.iforest import IForest

# Shared helpers for the rest of the script.
datagenerator = DataGenerator()  # produces the train/test splits of a named dataset
utils = Utils()  # provides the evaluation metric helper

# A single FastICA instance; it is refit on every dataset it is applied to.
ica = FastICA()
# Candidate dataset names, kept in plain string (lexicographic) order.
data_list = [
    '10_cover',
    '11_donors',
    '12_fault',
    '13_fraud',
    '14_glass',
    '15_Hepatitis',
    '16_http',
    '17_InternetAds',
    '18_Ionosphere',
    '19_landsat',
    '1_ALOI',
    '20_letter',
    '21_Lymphography',
    '22_magic.gamma',
    '23_mammography',
    '24_mnist',
    '25_musk',
    '26_optdigits',
    '27_PageBlocks',
    '28_pendigits',
    '29_Pima',
    '2_annthyroid',
    '30_satellite',
    '31_satimage-2',
    '32_shuttle',
    '33_skin',
    '34_smtp',
    '35_SpamBase',
    '36_speech',
    '37_Stamps',
    '38_thyroid',
    '39_vertebral',
    '3_backdoor',
    '40_vowels',
    '41_Waveform',
    '42_WBC',
    '43_WDBC',
    '44_Wilt',
    '45_wine',
    '46_WPBC',
    '47_yeast',
    '4_breastw',
    '5_campaign',
    '6_cardio',
    '7_Cardiotocography',
    '8_celeba',
    '9_census',
]


# Anchor dataset: the one a model is later fitted and evaluated on. It is
# removed from the candidates so it is never compared against itself.
anchor_dataset = "31_satimage-2"
data_list.remove(anchor_dataset)

# Load the anchor data and embed its training split with FastICA.
datagenerator.dataset = anchor_dataset
anchor_data = datagenerator.generator(
    la=0.1,
    realistic_synthetic_mode=None,
    noise_type=None,
)
x = ica.fit_transform(anchor_data['X_train'])
# Gromov-Wasserstein cost between the anchor embedding `x` and the embedding
# of every candidate dataset, keyed by dataset name.
distances = {}

# The anchor geometry and the solver do not depend on the candidate dataset,
# so they are built once instead of on every iteration.
geom_xx = ott.geometry.pointcloud.PointCloud(x)
solver = ott.core.gromov_wasserstein.GromovWasserstein(rank=6)

for dataset in data_list:
    datagenerator.dataset = dataset  # specify the dataset name
    data = datagenerator.generator(la=0.1, realistic_synthetic_mode=None, noise_type=None)
    # The shared FastICA instance is refit on this candidate's training split.
    y = ica.fit_transform(data['X_train'])
    geom_yy = ott.geometry.pointcloud.PointCloud(y)

    # Quadratic problem between the anchor and candidate point clouds.
    prob = ott.core.quad_problems.QuadraticProblem(geom_xx, geom_yy)
    ot_gwlr = solver(prob)

    # Use the last strictly positive entry of the solver's cost history as the
    # distance (presumably non-positive entries are placeholders for
    # iterations that never ran -- TODO confirm against the ott docs).
    # NOTE(review): raises IndexError if no entry is positive.
    final_cost = ot_gwlr.costs[ot_gwlr.costs > 0][-1]
    print(f"GWLR = {final_cost}")
    distances[dataset] = final_cost

# Pick the candidate dataset with the smallest distance to the anchor.
most_similar_dataset = min(distances, key=distances.get)

# Look up the model recorded for that dataset in the metadata store.
meta_store = pd.read_csv('metadatastore.csv')
index_model = meta_store.loc[meta_store['dataset'] == most_similar_dataset]
if index_model.empty:
    # Same exception type as the original `list(...)[0]`, but with a message
    # that says which dataset is missing.
    raise IndexError(
        f"no entry for dataset {most_similar_dataset!r} in metadatastore.csv"
    )
# The stored model text may span several lines; strip newlines before parsing.
model_string = str(index_model['model'].iloc[0]).replace('\n', '')
# SECURITY: eval() executes arbitrary code taken from metadatastore.csv.
# Only run this script against a metadata store you trust.
model = eval(model_string)
model = model['0']

# Fit the selected model on the anchor training split and score the test split.
model = model.fit(anchor_data['X_train'])
score = model.decision_function(anchor_data['X_test'])
result = utils.metric(y_true=anchor_data['y_test'], y_score=score)
# Report the metrics here: the wandb logging below is disabled, so otherwise
# they would never be shown.
print(f"AutoOD ({most_similar_dataset}) result = {result}")
# wandb.log({"model": str(model)})
# wandb.log({"aucroc_autood": result['aucroc']})
# wandb.log({"aucpr_autood": result['aucpr']})
# Baseline for comparison: an Isolation Forest with default settings, trained
# and evaluated on the same anchor splits.
baseline = IForest()
# NOTE: `model`, `score` and `result` are rebound here, replacing whatever
# values they held before this point.
model = baseline.fit(anchor_data['X_train'])
score = baseline.decision_function(anchor_data['X_test'])
result = utils.metric(y_true=anchor_data['y_test'], y_score=score)
# Report the baseline metrics; nothing else in the script outputs them.
print(f"IForest baseline result = {result}")